{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pretraining on WT103"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from exp.nb_12a import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One time download"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=7410)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#path = datasets.Config().data_path()\n",
    "#version = '103' #2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#! wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-{version}-v1.zip -P {path}\n",
    "#! unzip -q -n {path}/wikitext-{version}-v1.zip  -d {path}\n",
    "#! mv {path}/wikitext-{version}/wiki.train.tokens {path}/wikitext-{version}/train.txt\n",
    "#! mv {path}/wikitext-{version}/wiki.valid.tokens {path}/wikitext-{version}/valid.txt\n",
    "#! mv {path}/wikitext-{version}/wiki.test.tokens {path}/wikitext-{version}/test.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the articles: WT103 is given as one big text file and we need to chunk it in different articles if we want to be able to shuffle them at the beginning of each epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = datasets.Config().data_path()/'wikitext-103'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def istitle(line):\n",
    "    return len(re.findall(r'^ = [^=]* = $', line)) != 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_wiki(filename):\n",
    "    articles = []\n",
    "    with open(filename, encoding='utf8') as f:\n",
    "        lines = f.readlines()\n",
    "    current_article = ''\n",
    "    for i,line in enumerate(lines):\n",
    "        current_article += line\n",
    "        if i < len(lines)-2 and lines[i+1] == ' \\n' and istitle(lines[i+2]):\n",
    "            current_article = current_article.replace('<unk>', UNK)\n",
    "            articles.append(current_article)\n",
    "            current_article = ''\n",
    "    current_article = current_article.replace('<unk>', UNK)\n",
    "    articles.append(current_article)\n",
    "    return articles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = TextList(read_wiki(path/'train.txt'), path=path) #+read_file(path/'test.txt')\n",
    "valid = TextList(read_wiki(path/'valid.txt'), path=path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train), len(valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sd = SplitData(train, valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "proc_tok,proc_num = TokenizeProcessor(),NumericalizeProcessor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ll = label_by_func(sd, lambda x: 0, proc_x = [proc_tok,proc_num])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(ll, open(path/'ld.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ll = pickle.load( open(path/'ld.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs,bptt = 128,70\n",
    "data = lm_databunchify(ll, bs, bptt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = ll.train.proc_x[-1].vocab\n",
    "len(vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dps = np.array([0.1, 0.15, 0.25, 0.02, 0.2]) * 0.2\n",
    "tok_pad = vocab.index(PAD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_sz, nh, nl = 300, 300, 2\n",
    "model = get_language_model(len(vocab), emb_sz, nh, nl, tok_pad, *dps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cbs = [partial(AvgStatsCallback,accuracy_flat),\n",
    "       CudaCallback, Recorder,\n",
    "       partial(GradientClipping, clip=0.1),\n",
    "       partial(RNNTrainer, α=2., β=1.),\n",
    "       ProgressCallback]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(model, data, cross_entropy_flat, lr=5e-3, cb_funcs=cbs, opt_func=adam_opt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 5e-3\n",
    "sched_lr  = combine_scheds([0.3,0.7], cos_1cycle_anneal(lr/10., lr, lr/1e5))\n",
    "sched_mom = combine_scheds([0.3,0.7], cos_1cycle_anneal(0.8, 0.7, 0.8))\n",
    "cbsched = [ParamScheduler('lr', sched_lr), ParamScheduler('mom', sched_mom)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(10, cbs=cbsched)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(learn.model.state_dict(), path/'pretrained.pth')\n",
    "pickle.dump(vocab, open(path/'vocab.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
